import os
import random
import shutil
import time
from calendar import c

# Full list of method names, kept for reference:
#   ["GA_FT_book", "GA", "KL", "CL", "RL", "origin"]
# Methods = ["CL"]
# Methods iterated by gen_commands_new_eval().
Methods = ["GA", "GA_FT_book", "CL"]

# Extra command-line flags appended to a command, keyed by method name.
Params = dict(
    GA="--unlearn.lr 5e-6",
    KL="--unlearn.lr 1e-5 --unlearn.KL.gamma 0.1",
    GA_FT_book="--unlearn.lr 5e-6 --unlearn.GA+FT.gamma 2.0",
    CL="--unlearn.lr 1e-5",
    RL="--unlearn.lr 1e-5",
    origin="",
)

# Values passed as --unlearn.num_epochs, keyed by method name.
Epochs = dict(
    GA=[1, 3],
    GA_FT_book=[1, 3],
    CL=[1, 3, 5, 7],
)

# Flag bundle that switches on the Sophia options.
sophia = (
    "--unlearn.sophia 1"
    " --unlearn.sophia_params.betas_low 0.9"
    " --unlearn.sophia_params.betas_high 0.95"
    " --unlearn.sophia_params.rho 0.03"
)


def gen_commands_GA_FT_PILE():
    """Build the gamma sweep for the ``GA_FT_Pile`` Detoxify config.

    Returns:
        list[str]: one command per (lr, gamma, epoch) combination, with lr
        as the outermost axis and epoch as the innermost.
    """
    learning_rates = [5e-6]
    gamma_grid = [0.01, 0.1, 0.5, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0]
    epoch_grid = [5]
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/GA_FT_Pile.json"
        f" --unlearn.lr {rate}"
        f" --unlearn.GA+FT.gamma {g}"
        f" --unlearn.num_epochs {n_epochs}"
        f" --logger.json.root files/results/unlearn_toxic/GA_FT_Pile/opt_1b_lr_{rate}_gamma_{g}_epoch_{n_epochs}/"
        for rate in learning_rates
        for g in gamma_grid
        for n_epochs in epoch_grid
    ]


def gen_commands():
    """Build Detoxify commands for a local list of methods.

    Methods named ``KL`` or ``GA_FT_book`` are swept over gamma (outer) and
    lr (inner); every other method is swept over lr only.

    Returns:
        list[str]: the generated command lines.
    """
    method_names = ["GA"]
    learning_rates = [1e-4]
    gamma_grid = [1.0, 0.1, 0.01, 0.001]
    # Methods that take a gamma flag, mapped to that flag's name.
    gamma_flag = {
        "KL": "--unlearn.KL.gamma",
        "GA_FT_book": "--unlearn.GA+FT.gamma",
    }
    runner = "python src/exec/unlearn_model.py --config-file configs/unlearn/Detoxify"
    results = "files/results/unlearn_toxic"

    out = []
    for name in method_names:
        flag = gamma_flag.get(name)
        if flag is None:
            # No gamma for this method: lr-only sweep.
            out.extend(
                f"{runner}/{name}.json --unlearn.lr {rate}"
                f" --logger.json.root {results}/{name}/opt_1b_lr_{rate}/"
                for rate in learning_rates
            )
            continue
        for g in gamma_grid:
            for rate in learning_rates:
                out.append(
                    f"{runner}/{name}.json --unlearn.lr {rate} {flag} {g}"
                    f" --logger.json.root {results}/{name}/opt_1b_lr_{rate}_gamma_{g}/"
                )
    return out


def gen_commands_CL_joint():
    """Build 5-epoch CL commands over joint-score masks.

    For every score type and ratio, ``--unlearn.p`` and ``--unlearn.q`` are
    both set to the ratio and the mask file for that pair is passed.

    Returns:
        list[str]: commands ordered by score type, then ratio.
    """
    score_kinds = ("hessianfree_score_joint", "gradient_score_joint")
    ratio_grid = (0.2, 0.4, 0.6, 0.8)
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/CL.json"
        f" --unlearn.p {r} --unlearn.q {r}"
        f" --logger.json.root files/results/unlearn_toxic/CL/opt_1b_p_{r}_q_{r}_{kind}_epoch5/"
        f" --unlearn.mask_path files/results/unlearn_toxic/opt1.3b/mask/{kind}/with_{r}_{r}.pt"
        " --unlearn.num_epochs 5"
        for kind in score_kinds
        for r in ratio_grid
    ]


def gen_commands_GA_mask():
    """Build GA commands using the ``hessianfree`` mask files.

    Returns:
        list[str]: one command per (ratio, lr) pair, ratio outermost.
    """
    mask_ratios = [0.4, 0.2, 0.6]
    learning_rates = [1e-5, 5e-5]
    out = []
    for mask_ratio in mask_ratios:
        out += [
            "python src/exec/unlearn_model.py"
            " --config-file configs/unlearn/Detoxify/GA.json"
            f" --unlearn.lr {rate}"
            f" --logger.json.root files/results/unlearn_toxic/opt_1b_ratio_hessianfree_mask_{mask_ratio}_{rate}/"
            f" --unlearn.mask_path files/results/unlearn_toxic/opt1.3b/mask/hessianfree/with_{mask_ratio}.pt"
            for rate in learning_rates
        ]
    return out


def gen_commands_ratio_tofu():
    """Build Tofu commands over mask keys and ratios.

    A ratio of 1.0 produces a single command without ``--unlearn.mask_path``,
    emitted only for the ``snip_forget`` key; the other keys skip ratio 1.0.

    Returns:
        list[str]: commands ordered by method, key, then ratio.
    """
    method_names = ["CL_llama"]
    mask_keys = ["snip_forget", "gradient", "random"]
    ratio_grid = [0.2, 0.4, 0.6, 0.8, 1.0]

    out = []
    for name in method_names:
        # Shared command head, up to and including the results directory.
        head = (
            f"python src/exec/unlearn_model.py --config-file configs/unlearn/Tofu/{name}.json"
            f" --logger.json.root files/results/unlearn_tofu/{name}/"
        )
        for key in mask_keys:
            for ratio in ratio_grid:
                if ratio != 1.0:
                    out.append(
                        f"{head}llama7b_ratio_{key}_mask_{ratio}/"
                        f" --unlearn.mask_path files/results/unlearn_tofu/llama7b/mask/{key}/with_{ratio}.pt"
                    )
                elif key == "snip_forget":
                    out.append(f"{head}llama7b_ratio_{ratio}/")
    return out


def gen_commands_ratio_toxic():
    """Build Detoxify commands over mask keys and ratios.

    A ratio of 1.0 would produce a command without ``--unlearn.mask_path``,
    and only for the ``snip_forget`` key; with the current ratio list that
    path is never taken.

    Returns:
        list[str]: commands ordered by method, key, then ratio.
    """
    method_names = ["CL_llama"]
    mask_keys = ["wanda"]
    ratio_grid = [0.2, 0.4, 0.6, 0.8]

    out = []
    for name in method_names:
        # Shared command head, up to and including the results directory.
        head = (
            f"python src/exec/unlearn_model.py --config-file configs/unlearn/Detoxify/{name}.json"
            f" --logger.json.root files/results/unlearn_Detoxify/{name}/"
        )
        for key in mask_keys:
            for ratio in ratio_grid:
                if ratio != 1.0:
                    out.append(
                        f"{head}llama7b_ratio_{key}_mask_{ratio}/"
                        f" --unlearn.mask_path files/results/unlearn_Detoxify/llama7b/mask/{key}/with_{ratio}.pt"
                    )
                elif key == "snip_forget":
                    out.append(f"{head}llama7b_ratio_{ratio}/")
    return out


def gen_commands_sophia():
    """Build the GA_FT_book gamma sweep with the Sophia flags at lr 5e-6.

    Returns:
        list[str]: one command per gamma value, in list order.
    """
    # Spelled out locally so the function does not depend on module state.
    sophia_flags = (
        "--unlearn.sophia 1"
        " --unlearn.sophia_params.betas_low 0.9"
        " --unlearn.sophia_params.betas_high 0.95"
        " --unlearn.sophia_params.rho 0.03"
    )
    gamma_grid = [1.0, 0.1, 0.5, 0.01, 0.05]
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/GA_FT_book.json"
        f" --logger.json.root files/results/unlearn_toxic/GA_FT_book/sophia_gamma_{g}"
        f" --unlearn.GA+FT.gamma {g}"
        f" {sophia_flags}"
        " --unlearn.lr 5e-6"
        for g in gamma_grid
    ]


def gen_commands_ratios():
    """Build Detoxify commands over mask keys and ratios, plus ``Params`` flags.

    The method ``origin`` yields one command with no ratio sweep. For other
    methods each ratio yields a masked command, except that ratio 1.0 with
    the ``random`` key yields a command without ``--unlearn.mask_path``.
    Every command ends with a space followed by ``Params[method]``.

    Returns:
        list[str]: commands ordered by key, method, then ratio.
    """
    method_names = ["CL"]
    mask_keys = ["snip_forget"]
    ratio_grid = [0.2, 0.4, 0.6, 0.8]
    runner = "python src/exec/unlearn_model.py --config-file configs/unlearn/Detoxify"
    results = "files/results/unlearn_toxic"

    out = []
    for key in mask_keys:
        for name in method_names:
            if name == "origin":
                out.append(
                    f"{runner}/{name}.json --logger.json.root {results}/{name}"
                    + " "
                    + Params[name]
                )
                continue
            for ratio in ratio_grid:
                if ratio == 1.0 and key == "random":
                    body = f"{runner}/{name}.json --logger.json.root {results}/{name}/opt_1b_ratio_{ratio}/"
                else:
                    # The double space after ".json" is part of the emitted command.
                    body = (
                        f"{runner}/{name}.json  --logger.json.root {results}/{name}/opt_1b_ratio_{key}_mask_{ratio}/"
                        f" --unlearn.mask_path {results}/opt1.3b/mask/{key}/with_{ratio}.pt"
                    )
                out.append(body + " " + Params[name])
    return out


def gen_commands_CL_FT_sophia():
    """Build the CL+FT gamma sweep with the Sophia flags enabled.

    Returns:
        list[str]: one command per gamma value, in list order.
    """
    gamma_grid = [0.9, 0.5, 0.7, 0.3, 0.1]
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/CL+FT.json"
        f" --unlearn.CL+FT.gamma {g}"
        f" --logger.json.root files/results/unlearn_toxic/CL_FT/opt_1b_gamma_{g}_sophia/"
        " --unlearn.sophia 1"
        " --unlearn.sophia_params.betas_low 0.9"
        " --unlearn.sophia_params.betas_high 0.95"
        " --unlearn.sophia_params.rho 0.03"
        for g in gamma_grid
    ]


def gen_commands_CL_FT_KL():
    """Build gamma sweeps for the ``CL+FT`` and ``CL+KL`` Detoxify configs.

    Each method has its own explicit gamma grid. The grids are the values
    the earlier implementation actually ran: it rebound the list while
    iterating it, so ``CL+FT`` used ``[0.01, 0.1, 1]`` and ``CL+KL`` used
    the list left over from that rebinding.

    Returns:
        list[str]: commands for ``CL+FT`` followed by commands for ``CL+KL``.
    """
    # Insertion order is the emission order.
    gamma_grid = {
        "CL+FT": [0.01, 0.1, 1],
        "CL+KL": [0.01, 0.001, 1e-4],
    }
    commands = []
    for method, gammas in gamma_grid.items():
        for gamma in gammas:
            commands.append(
                f"python src/exec/unlearn_model.py --config-file configs/unlearn/Detoxify/{method}.json"
                f" --unlearn.{method}.gamma {gamma}"
                f" --logger.json.root files/results/unlearn_toxic/{method}/opt_1b_gamma_{gamma}/"
            )
    return commands


def gen_commands_new_eval():
    """Build one Detoxify command per (method, epoch) from the module tables.

    Iterates the module-level ``Methods`` list, takes epoch counts from
    ``Epochs[method]`` and appends ``Params[method]`` after a space.

    NOTE(review): the results directory is named ``sophia_epoch_<n>`` but no
    Sophia flags are added here - confirm the name is intended.

    Returns:
        list[str]: commands ordered by method, then epoch.
    """
    return [
        (
            f"python src/exec/unlearn_model.py --config-file configs/unlearn/Detoxify/{name}.json"
            f" --logger.json.root files/results/unlearn_toxic/{name}/sophia_epoch_{n_epochs}"
            f" --unlearn.num_epochs {n_epochs}"
        )
        + " "
        + Params[name]
        for name in Methods
        for n_epochs in Epochs[name]
    ]


def gen_commands_GA():
    """Build the GA grid over learning rate and epoch count.

    Returns:
        list[str]: one command per (lr, epoch) pair, lr outermost.
    """
    learning_rates = [1e-4, 1e-5, 5e-5, 1e-6, 5e-6]
    epoch_grid = [1, 3, 5, 10]
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/GA.json"
        f" --unlearn.lr {rate}"
        f" --unlearn.num_epochs {n_epochs}"
        f" --logger.json.root files/results/unlearn_toxic/GA/opt_1b_lr_{rate}_epoch_{n_epochs}/"
        for rate in learning_rates
        for n_epochs in epoch_grid
    ]


def gen_commands_snip_joint():
    """Build CL commands using ``snip_joint`` masks with ``q`` equal to ``p``.

    Each command ends with a space followed by ``Params["CL"]``.

    Returns:
        list[str]: one command per ``p`` value, in list order.
    """
    p_grid = [0.2, 0.4, 0.5, 0.6, 0.8]
    extra_flags = Params["CL"]
    out = []
    for p_val in p_grid:
        q_val = p_val  # q always mirrors p in this sweep
        core = (
            "python src/exec/unlearn_model.py"
            " --config-file configs/unlearn/Detoxify/CL.json"
            f" --unlearn.p {p_val} --unlearn.q {q_val}"
            f" --logger.json.root files/results/unlearn_toxic/CL/opt_1b_p_{p_val}_q_{q_val}/"
            f" --unlearn.mask_path files/results/unlearn_toxic/opt1.3b/mask/snip_joint/with_{p_val}_{q_val}.pt"
        )
        out.append(core + " " + extra_flags)
    return out


def gen_commands_GA_FT():
    """Build the GA_FT_book gamma sweep at a fixed learning rate of 5e-6.

    The learning rate is written literally as ``5e-6`` in both the flag and
    the results directory name.

    Returns:
        list[str]: one command per gamma value, in list order.
    """
    gamma_grid = [0.1, 1.0, 5.0, 0.5]
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/GA_FT_book.json"
        " --unlearn.lr 5e-6"
        f" --unlearn.GA+FT.gamma {g}"
        f" --logger.json.root files/results/unlearn_toxic/GA_FT_book/opt_1b_lr_5e-6_gamma_{g}/"
        for g in gamma_grid
    ]


def gen_commands_CL():
    """Build the CL epoch sweep; each command ends with ``Params["CL"]``.

    Returns:
        list[str]: one command per epoch count, in list order.
    """
    epoch_grid = [1, 3, 5, 7]
    return [
        (
            "python src/exec/unlearn_model.py"
            " --config-file configs/unlearn/Detoxify/CL.json"
            f" --unlearn.num_epochs {n_epochs}"
            f" --logger.json.root files/results/unlearn_toxic/CL/opt_1b_epoch_{n_epochs}/"
        )
        + " "
        + Params["CL"]
        for n_epochs in epoch_grid
    ]




def run_commands(
    commands,
    call=False,
    gpus=0,
    dir="commands",
    shuffle=False,
    delay=0.5,
    hpcc=False,
    if_llama=False,
    per_gpu=2,
):
    """Split ``commands`` into per-worker shell scripts and optionally launch them.

    The directory ``dir`` is deleted and recreated, a ``stop_<dir>.sh`` helper
    is written to the current working directory, and the commands are dealt
    round-robin into ``run<i>`` scripts inside ``dir``.

    Args:
        commands: Iterable of shell command strings. It is copied first, so
            ``shuffle`` never reorders the caller's list.
        call: If true, start every generated script in the background.
        gpus: GPU ids to use; a sequence, or a single id such as ``0``.
        dir: Output directory for the scripts. WARNING: removed if it exists.
        shuffle: Randomise the command order before distributing.
        delay: Seconds to sleep after each script launch.
        hpcc: Read ``./scripts/slurm/exp.sb`` and use its content as the
            per-command prefix, name scripts ``*.sb`` and submit them with
            ``sbatch``. In ``if_llama`` mode the file is still read, but the
            prefix is replaced and scripts are started with ``bash``.
        if_llama: Create ``per_gpu`` scripts, each restricted to the GPU
            subset ``gpus[i::per_gpu]``, instead of one script per GPU.
        per_gpu: Number of scripts (GPU groups) in ``if_llama`` mode.

    Raises:
        ValueError: If ``gpus`` is empty, which would otherwise drop every
            command silently.
    """
    # Accept a bare GPU id (including the default 0) as well as a sequence.
    try:
        gpus = list(gpus)
    except TypeError:
        gpus = [gpus]
    # Validate before touching the filesystem so a bad call deletes nothing.
    if not gpus:
        raise ValueError("gpus must contain at least one GPU id")

    commands = list(commands)
    if shuffle:
        random.shuffle(commands)

    if os.path.exists(dir):
        shutil.rmtree(dir)
    os.makedirs(dir, exist_ok=True)

    if hpcc:
        with open(os.path.join("./scripts/slurm", "exp.sb"), "r") as file:
            prefix = file.read()

    # Helper script that kills the "bash <dir>..." processes started below.
    with open("stop_{}.sh".format(dir), "w") as fout:
        print("kill $(ps aux|grep 'bash " + dir + "'|awk '{print $2}')", file=fout)

    print("total: {}".format(len(commands)))

    n_gpu = len(gpus)

    if if_llama:
        for i in range(per_gpu):
            i_commands = commands[i::per_gpu]
            gpu_local = gpus[i::per_gpu]
            prefix = "CUDA_VISIBLE_DEVICES={} ".format(",".join(map(str, gpu_local)))
            sh_path = os.path.join(dir, "run{}.sb".format(i))
            with open(sh_path, "w") as fout:
                for com in i_commands:
                    print(prefix + com, file=fout)
            if call:
                os.system("bash {}&".format(sh_path))
                # Stagger launches; nothing to wait for when only writing files.
                time.sleep(delay)
    else:
        for i, gpu in enumerate(gpus):
            i_commands = commands[i::n_gpu]
            if not hpcc:
                prefix = "CUDA_VISIBLE_DEVICES={} ".format(gpu)
                sh_path = os.path.join(dir, "run{}.sh".format(i))
            else:
                sh_path = os.path.join(dir, "run{}.sb".format(i))
            with open(sh_path, "w") as fout:
                for com in i_commands:
                    print(prefix + com, file=fout)
            if call:
                if not hpcc:
                    os.system("bash {}&".format(sh_path))
                else:
                    os.system("sbatch {}&".format(sh_path))
                time.sleep(delay)


def gen_commands_unlearn_copyright():
    """Build the Copyright grid over learning rate and epoch count.

    The method ``GA`` uses epochs ``[1, 3, 5]``; any other method uses
    ``[1, 3, 5, 7, 10]``.

    Returns:
        list[str]: commands ordered by method, lr, then epoch.
    """
    method_names = ["CL"]
    learning_rates = [1e-5, 5e-5, 5e-6]

    out = []
    for name in method_names:
        epoch_grid = [1, 3, 5] if name == "GA" else [1, 3, 5, 7, 10]
        out.extend(
            f"python src/exec/unlearn_model.py --config-file configs/unlearn/Copyright/{name}.json"
            f" --unlearn.lr {rate} --unlearn.num_epochs {n_epochs}"
            f" --logger.json.root files/results/unlearn_copyright/{name}/opt_1b_lr_{rate}_epoch_{n_epochs}/"
            for rate in learning_rates
            for n_epochs in epoch_grid
        )
    return out


def gen_commands_copyright_ratios():
    """Build Copyright commands over mask keys and ratios.

    The method ``origin`` yields one command ending in ``Params["origin"]``.
    For other methods each ratio below 1.0 yields a masked command; ratio
    1.0 yields an unmasked command only when the method is ``gradient``.

    NOTE(review): the ratio-1.0 test compares the *method* with "gradient",
    which looks like a mask key name - confirm whether ``key`` was meant.

    Returns:
        list[str]: commands ordered by key, method, then ratio.
    """
    method_names = ["CL"]
    mask_keys = ["snip_advanced_new"]
    ratio_grid = [0.2, 0.4, 0.6, 0.8]
    runner = "python src/exec/unlearn_model.py --config-file configs/unlearn/Copyright"
    results = "files/results/unlearn_copyright"

    out = []
    for key in mask_keys:
        for name in method_names:
            if name == "origin":
                out.append(
                    f"{runner}/{name}.json --logger.json.root {results}/{name}"
                    + " "
                    + Params[name]
                )
                continue
            for ratio in ratio_grid:
                if ratio != 1.0:
                    out.append(
                        f"{runner}/{name}.json --logger.json.root {results}/{name}/opt_1b_ratio_{key}_mask_{ratio}/"
                        f" --unlearn.mask_path {results}/opt1.3b/mask/{key}/with_{ratio}.pt"
                    )
                elif name == "gradient":
                    out.append(
                        f"{runner}/{name}.json --logger.json.root {results}/{name}/opt_1b_ratio_{ratio}/"
                    )
    return out


def gen_commands_ffn():
    """Build CL commands with the ``FFN`` mask for Copyright and Detoxify.

    Returns:
        list[str]: commands ordered by task, key, method, then ratio.
    """
    # Task name -> results sub-directory used for both logger root and mask.
    results_dir = {
        "Copyright": "unlearn_copyright",
        "Detoxify": "unlearn_toxic",
    }
    task_names = ["Copyright", "Detoxify"]
    method_names = ["CL"]
    mask_keys = ["FFN"]
    ratio_grid = [0.0]

    out = []
    for task in task_names:
        root = f"files/results/{results_dir[task]}"
        for key in mask_keys:
            for name in method_names:
                for ratio in ratio_grid:
                    out.append(
                        f"python src/exec/unlearn_model.py --config-file configs/unlearn/{task}/{name}.json"
                        f" --logger.json.root {root}/{name}/opt_1b_ratio_{key}_mask_{ratio}/"
                        f" --unlearn.mask_path {root}/opt1.3b/mask/{key}/with_{ratio}.pt"
                    )
    return out


def gen_commands_unlearn_wmdp():
    """Build the wmdp grid over learning rate and gamma.

    ``--unlearn.max_steps`` is 500 when the learning rate is 1e-6 and 50
    otherwise.

    Returns:
        list[str]: commands ordered by method, lr, then gamma.
    """
    method_names = ["GA_FT_zephyr_bio"]
    gamma_grid = [0.1, 1.0, 5.0, 10.0, 200.0, 500.0]
    learning_rates = [1e-6, 5e-6]

    out = []
    for name in method_names:
        for rate in learning_rates:
            # The step budget depends only on the learning rate.
            step_cap = 500 if rate == 1e-6 else 50
            for g in gamma_grid:
                out.append(
                    f"python src/exec/unlearn_model.py --config-file configs/unlearn/wmdp/{name}.json"
                    f" --unlearn.lr {rate} --unlearn.GA+FT.gamma {g} --unlearn.max_steps {step_cap}"
                    f" --logger.json.root files/results/unlearn_wmdp/{name}/opt_1b_lr_{rate}_gamma_{g}_max_step_{step_cap}/"
                )
    return out


def gen_commands_unlearn_PII():
    """Build PII commands: an lr x gamma grid, plus one fixed CL command.

    The method ``CL`` yields a single command with the literal lr ``1e-5``;
    every other method is swept over lr (outer) and gamma (inner).

    Returns:
        list[str]: commands in method order.
    """
    method_names = ["GA_FT_book", "CL"]
    learning_rates = [1e-6, 5e-6, 1e-5]
    gamma_grid = [0.1, 1.0, 5.0, 0.0, 0.01]
    runner = "python src/exec/unlearn_model.py --config-file configs/unlearn/PII"

    out = []
    for name in method_names:
        root = f"files/results/unlearn_PII/{name}"
        if name != "CL":
            out.extend(
                f"{runner}/{name}.json --unlearn.lr {rate} --unlearn.GA+FT.gamma {g}"
                f" --logger.json.root {root}/opt_1b_lr_{rate}_gamma_{g}/"
                for rate in learning_rates
                for g in gamma_grid
            )
        else:
            out.append(
                f"{runner}/{name}.json --unlearn.lr 1e-5"
                f" --logger.json.root {root}/opt_1b_lr_1e-5/"
            )
    return out


def gen_commands_unlearn_PII_ratio():
    """Build the 5-epoch PII learning-rate sweep for CL.

    The ratio list only controls how many times the lr sweep is repeated;
    the ratio value does not appear in the command and no mask path is
    passed.

    Returns:
        list[str]: commands ordered by method, key, ratio, then lr.
    """
    method_names = ["CL"]
    learning_rates = [2e-5, 3e-5, 4e-5, 5e-5]
    n_epochs = 5
    mask_keys = ["snip_forget"]
    ratio_grid = [0.2]

    out = []
    for name in method_names:
        for key in mask_keys:
            for _ in ratio_grid:
                out.extend(
                    f"python src/exec/unlearn_model.py --config-file configs/unlearn/PII/{name}.json"
                    f" --unlearn.lr {rate} --unlearn.num_epochs {n_epochs}"
                    f" --logger.json.root files/results/unlearn_PII/{name}/opt_1b_ratio_{key}_lr_{rate}"
                    for rate in learning_rates
                )
    return out


def gen_commands_unlearn_PII_ratio_llama():
    """Build the 5-epoch PII learning-rate sweep for CL_llama with masks.

    Each emitted command ends with a single trailing space, which is kept
    as-is.

    Returns:
        list[str]: commands ordered by method, key, ratio, then lr.
    """
    method_names = ["CL_llama"]
    learning_rates = [1e-5, 2e-5, 3e-5, 5e-5]
    n_epochs = 5
    mask_keys = ["wanda"]
    ratio_grid = [1.0]

    out = []
    for name in method_names:
        for key in mask_keys:
            for ratio in ratio_grid:
                out.extend(
                    f"python src/exec/unlearn_model.py --config-file configs/unlearn/PII/{name}.json"
                    f" --unlearn.lr {rate} --unlearn.num_epochs {n_epochs}"
                    f" --logger.json.root files/results/unlearn_PII/{name}/llama7b_ratio_{key}_mask_{ratio}_lr_{rate}"
                    f" --unlearn.mask_path files/results/unlearn_PII/llama7b/mask/{key}/with_{ratio}.pt "
                    for rate in learning_rates
                )
    return out


def gen_commands_unlearn_lora():
    """Build the Detoxify CL learning-rate sweep with ``--unlearn.use_lora 1``.

    Returns:
        list[str]: one command per learning rate, in list order.
    """
    learning_rates = [1e-4, 5e-5, 1e-3]
    # The double space after "CL.json" is part of the emitted command.
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/CL.json "
        f" --logger.json.root files/results/unlearn_toxic/CL/opt_1b_ratio_lora_lr{rate}"
        f" --unlearn.use_lora 1 --unlearn.lr {rate}"
        for rate in learning_rates
    ]


def gen_commands_unlearn_lora_llama():
    """Build the Detoxify CL_llama learning-rate sweep with ``--unlearn.use_lora 1``.

    Returns:
        list[str]: one command per learning rate, in list order.
    """
    learning_rates = [1e-4, 5e-5, 1e-3]
    # The double space after "CL_llama.json" is part of the emitted command.
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Detoxify/CL_llama.json "
        f" --logger.json.root files/results/unlearn_Detoxify/CL/llama7b_lora_lr{rate}"
        f" --unlearn.use_lora 1 --unlearn.lr {rate}"
        for rate in learning_rates
    ]

# def gen_commands_CL_FT_ratios():
#     commands = []
#     gammas = [0.1,1,5]
#     # keys = ["snip_forget", "gradient", "random"]
#     keys=["snip_forget_CL_FT","snip_forget_FT"]
#     ratios = [0.2, 0.4, 0.6, 0.8]
    
#     for key in keys:
#         for gamma in gammas:
#             for ratio in ratios:
#                 newkey = f"{key}_{gamma}"
#                 command = f"python src/exec/unlearn_model.py --config-file configs/unlearn/Tofu/CL_FT_llama.json --unlearn.CL+FT.gamma {gamma} --logger.json.root files/results/unlearn_tofu/CL_FT/llama7b_gamma_{gamma}_{newkey}_mask_ratio_{ratio} --unlearn.mask_path files/results/unlearn_tofu/llama7b/mask/{newkey}/with_{ratio}.pt"
#                 commands.append(command)
#     return commands
def gen_commands_CL_FT_ratios():
    """Build the Tofu CL_FT_llama grid over lr, mask key, gamma and ratio.

    Returns:
        list[str]: commands ordered by lr, key, gamma, then ratio.
    """
    gamma_grid = [1, 5]
    learning_rates = [5e-5, 1e-5]
    mask_keys = ["weight"]
    ratio_grid = [0.2, 0.4, 0.6, 0.8]
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Tofu/CL_FT_llama.json"
        f" --unlearn.CL+FT.gamma {g}"
        f" --unlearn.lr {rate}"
        f" --logger.json.root files/results/unlearn_tofu/CL_FT/llama7b_gamma_{g}_{key}_mask_ratio_{ratio}_lr_{rate}"
        f" --unlearn.mask_path files/results/unlearn_tofu/llama7b/mask/{key}/with_{ratio}.pt"
        for rate in learning_rates
        for key in mask_keys
        for g in gamma_grid
        for ratio in ratio_grid
    ]
def gen_commands_CL_ratios():
    """Build Tofu CL_llama commands over mask keys and ratios.

    Ratio 1.0 is emitted only for the ``snip_forget`` key and skipped for
    every other key.

    Returns:
        list[str]: commands ordered by key, then ratio.
    """
    mask_keys = ["union"]
    ratio_grid = [0.2, 0.4, 0.6, 0.8, 1.0]
    # The double space after "CL_llama.json" is part of the emitted command.
    return [
        "python src/exec/unlearn_model.py"
        " --config-file configs/unlearn/Tofu/CL_llama.json "
        f" --logger.json.root files/results/unlearn_tofu/CL/llama7b_{key}_mask_ratio_{ratio}"
        f" --unlearn.mask_path files/results/unlearn_tofu/llama7b/mask/{key}/with_{ratio}.pt"
        for key in mask_keys
        for ratio in ratio_grid
        if not (ratio == 1.0 and key != "snip_forget")
    ]


if __name__ == "__main__":
    # Select the sweep to run by leaving exactly ONE generator uncommented.
    # (An earlier active `commands = gen_commands_ratios()` line was
    # overwritten before use, so it is listed here as an option instead.)
    # commands = gen_commands()
    # commands = gen_commands_unlearn_copyright()
    # commands = gen_commands_ratios()
    # commands = gen_commands_snip_joint()
    # commands = gen_commands_CL_FT()
    # commands = gen_commands_GA_FT()
    # commands = gen_commands_GA_mask()
    # commands = gen_commands_unlearn_wmdp()
    # commands = gen_commands_CL_FT_sophia()
    # commands = gen_commands_CL_FT_KL()
    # commands1 = gen_commands_copyright_ratios()
    # commands = gen_commands_unlearn_lora()
    # commands = commands1 + commands2
    # commands = gen_commands_unlearn_PII()
    # commands = gen_commands_unlearn_PII_ratio_llama()
    # commands = gen_commands_ratio_toxic()
    # commands = gen_commands_unlearn_lora_llama()
    commands = gen_commands_CL_FT_ratios()
    # commands = gen_commands_CL_ratios()

    # Write the per-worker scripts and launch them: with if_llama=True the
    # 8 GPUs are split into per_gpu=4 groups, one script per group.
    run_commands(
        commands,
        call=True,
        gpus=[0, 1, 2, 3, 4, 5, 6, 7],
        dir="commands",
        shuffle=False,
        delay=0.5,
        hpcc=False,
        if_llama=True,
        per_gpu=4,
    )
    # Start clean_unlearn_checkpoint.sh in the background as well.
    os.system("bash {}&".format("clean_unlearn_checkpoint.sh"))
